{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import boto3\n",
    "import sagemaker\n",
    "import pandas as pd\n",
    "\n",
    "sess   = sagemaker.Session()\n",
    "bucket = sess.default_bucket()\n",
    "role = sagemaker.get_execution_role()\n",
    "region = boto3.Session().region_name\n",
    "\n",
    "sm = boto3.Session().client(service_name='sagemaker', region_name=region)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Specify the S3 Location of the Features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%store -r scikit_processing_job_s3_output_prefix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "print('Previous Scikit Processing Job Name: {}'.format(scikit_processing_job_s3_output_prefix))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "prefix_train = '{}/output/bert-train'.format(scikit_processing_job_s3_output_prefix)\n",
    "prefix_validation = '{}/output/bert-validation'.format(scikit_processing_job_s3_output_prefix)\n",
    "prefix_test = '{}/output/bert-test'.format(scikit_processing_job_s3_output_prefix)\n",
    "\n",
    "path_train = './{}'.format(prefix_train)\n",
    "path_validation = './{}'.format(prefix_validation)\n",
    "path_test = './{}'.format(prefix_test)\n",
    "\n",
    "train_s3_uri = 's3://{}/{}'.format(bucket, prefix_train)\n",
    "validation_s3_uri = 's3://{}/{}'.format(bucket, prefix_validation)\n",
    "test_s3_uri = 's3://{}/{}'.format(bucket, prefix_test)\n",
    "\n",
    "s3_input_train_data = sagemaker.s3_input(s3_data=train_s3_uri, content_type='text/csv')\n",
    "s3_input_validation_data = sagemaker.s3_input(s3_data=validation_s3_uri, content_type='text/csv')\n",
    "s3_input_test_data = sagemaker.s3_input(s3_data=test_s3_uri, content_type='text/csv')\n",
    "\n",
    "print(s3_input_train_data.config)\n",
    "print(s3_input_validation_data.config)\n",
    "print(s3_input_test_data.config)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!cat src_bert/bert_reviews.py"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from sagemaker.pytorch import PyTorch\n",
    "\n",
    "model_output_path = 's3://{}/models/bert/script-mode/training-runs'.format(bucket)\n",
    "\n",
    "bert_estimator = PyTorch(entry_point='bert_reviews.py',\n",
    "                         source_dir='src_bert',\n",
    "                         role=role,\n",
    "                         train_instance_count=1, # 1 is actually faster due to communication overhead with >1\n",
    "                         train_instance_type='ml.p3.8xlarge',\n",
    "                         py_version='py3',\n",
    "                         framework_version='1.4.0',\n",
    "                         output_path=model_output_path,\n",
    "                         hyperparameters={'model_type':'bert',\n",
    "                                          'model_name': 'bert-base-cased',\n",
    "                                          'backend': 'gloo'},\n",
    "                         enable_cloudwatch_metrics=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Train the model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bert_estimator.fit(inputs={'train': s3_input_train_data, \n",
    "                           'validation': s3_input_validation_data,}, \n",
    "                   wait=False) "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "training_job_name = bert_estimator.latest_training_job.name\n",
    "print('training_job_name:  {}'.format(training_job_name))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.core.display import display, HTML\n",
    "\n",
    "display(HTML('<b>Review <a href=\"https://console.aws.amazon.com/sagemaker/home?region={}#/jobs/{}\">Training Job</a> After About 5 Minutes</b>'.format(region, training_job_name)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.core.display import display, HTML\n",
    "\n",
    "display(HTML('<b>Review <a href=\"https://console.aws.amazon.com/cloudwatch/home?region={}#logStream:group=/aws/sagemaker/TrainingJobs;prefix={};streamFilter=typeLogStreamPrefix\">CloudWatch Logs</a> After About 5 Minutes</b>'.format(region, training_job_name)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.core.display import display, HTML\n",
    "\n",
    "training_job_s3_output_prefix = 'models/t2-bert/{}'.format(training_job_name) # 'models/tf-bert/script-mode/training-runs/{}'.format(training_job_name)\n",
    "\n",
    "display(HTML('<b>Review <a href=\"https://s3.console.aws.amazon.com/s3/buckets/{}/{}/?region={}&tab=overview\">S3 Output Data</a> After The Training Job Has Completed</b>'.format(bucket, training_job_s3_output_prefix, region)))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "scrolled": true
   },
   "outputs": [],
   "source": [
    "from sagemaker.pytorch import PyTorch\n",
    "\n",
    "bert_estimator = PyTorch.attach(training_job_name=training_job_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load the Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# download the model artifact from AWS S3\n",
    "!aws s3 cp $model_output_path/$training_job_name/output/model.tar.gz ./models/bert/"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tarfile\n",
    "import pickle as pkl\n",
    "\n",
    "tar = tarfile.open('./models/bert/model.tar.gz')\n",
    "tar.extractall(path='./models/bert-model')\n",
    "tar.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from simpletransformers.classification import ClassificationModel\n",
    "\n",
    "args = {\n",
    "   'fp16': False,\n",
    "   'max_seq_length': 128\n",
    "}\n",
    "\n",
    "bert_model = ClassificationModel(model_type='distilbert', # bert, distilbert, etc, etc.\n",
    "                                 model_name='./models/bert-model',\n",
    "                                 args=args,\n",
    "                                 use_cuda=False)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Predict \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictions, raw_outputs = bert_model.predict([\"\"\"I really enjoyed this item.  I highly recommend it.\"\"\"])\n",
    "\n",
    "print('Predictions: {}'.format(predictions))\n",
    "print('Raw outputs: {}'.format(raw_outputs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "predictions, raw_outputs = bert_model.predict([\"\"\"This item is awful and terrible.\"\"\"])\n",
    "\n",
    "print('Predictions: {}'.format(predictions))\n",
    "print('Raw outputs: {}'.format(raw_outputs))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "conda_python3",
   "language": "python",
   "name": "conda_python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
